import numpy as np
import pytest

from skfolio.utils.sorting import dominate, non_denominated_sort


@pytest.fixture(scope="module")
def fitnesses():
    """Provide a hard-coded 100 x 3 float array of fitness values.

    The values are fixed literals so that the fronts asserted in
    ``test_non_dominated_sort`` are reproducible. Presumably each row is one
    candidate and each column one fitness measure -- confirm against
    ``skfolio.utils.sorting``.

    The fixture is module-scoped, so the array is built once and the same
    instance is handed to every test in this module that requests it; tests
    should therefore treat it as read-only.
    """
    fitnesses = np.array(
        [
            [7.86504305e-04, -8.53775798e-03, -3.30448716e-01],
            [6.39299739e-04, -9.67442239e-03, -4.09008624e-01],
            [8.99144614e-04, -1.09226358e-02, -3.24804728e-01],
            [7.46622864e-04, -1.01365913e-02, -3.31309040e-01],
            [5.13079290e-04, -1.11645845e-02, -5.33852966e-01],
            [7.61099204e-04, -9.02503956e-03, -3.33791019e-01],
            [7.04927046e-04, -8.54185825e-03, -3.27197641e-01],
            [6.78115601e-04, -9.62955488e-03, -3.79057217e-01],
            [6.27076918e-04, -9.60003232e-03, -4.41513832e-01],
            [5.77215579e-04, -8.12052597e-03, -3.32249662e-01],
            [7.32333261e-04, -1.07163659e-02, -4.28513243e-01],
            [8.43210270e-04, -9.03070993e-03, -3.17642285e-01],
            [7.12472371e-04, -9.29380419e-03, -4.01208974e-01],
            [8.13207541e-04, -9.09570675e-03, -3.16223672e-01],
            [9.34813845e-04, -9.26509832e-03, -2.77886021e-01],
            [5.66361092e-04, -9.58427141e-03, -4.27397224e-01],
            [5.34934243e-04, -9.10651021e-03, -4.19638941e-01],
            [5.78452868e-04, -9.62742021e-03, -3.70782176e-01],
            [7.28746213e-04, -9.44268878e-03, -3.06483870e-01],
            [8.40449465e-04, -8.99537604e-03, -2.55783398e-01],
            [9.20605110e-04, -8.76232439e-03, -2.88443115e-01],
            [6.10137463e-04, -8.88996094e-03, -3.76522498e-01],
            [7.08868374e-04, -1.08518708e-02, -3.95514108e-01],
            [6.21642937e-04, -9.26945071e-03, -3.40666106e-01],
            [6.80420670e-04, -8.17355948e-03, -3.29734066e-01],
            [7.41941783e-04, -9.45157028e-03, -3.77567364e-01],
            [7.01588026e-04, -8.72915381e-03, -3.34537404e-01],
            [6.99204929e-04, -8.91014227e-03, -3.02468646e-01],
            [8.38415986e-04, -9.31153319e-03, -3.57844023e-01],
            [4.80098343e-04, -1.03979888e-02, -4.48348958e-01],
            [6.90745584e-04, -1.02130498e-02, -4.09907473e-01],
            [7.37336608e-04, -9.07662912e-03, -3.52989039e-01],
            [6.87403157e-04, -1.07352039e-02, -4.17443337e-01],
            [6.35898649e-04, -1.00814746e-02, -4.30869272e-01],
            [5.41323028e-04, -9.15503419e-03, -4.20134259e-01],
            [6.20049481e-04, -8.88389211e-03, -4.29522383e-01],
            [7.59914717e-04, -8.23087339e-03, -2.76775661e-01],
            [9.42547999e-04, -9.27667821e-03, -2.45478502e-01],
            [8.92619238e-04, -9.15203017e-03, -2.75936722e-01],
            [5.55316025e-04, -8.44880553e-03, -3.72893523e-01],
            [6.99622902e-04, -9.72418503e-03, -3.67854550e-01],
            [9.26999083e-04, -9.54285039e-03, -3.56437446e-01],
            [7.46285501e-04, -9.59510319e-03, -4.04020278e-01],
            [6.98966497e-04, -9.18217038e-03, -3.44788137e-01],
            [7.24923434e-04, -9.74561445e-03, -3.54984268e-01],
            [6.94886332e-04, -8.68598330e-03, -3.09994448e-01],
            [5.75101519e-04, -1.02984677e-02, -4.86806641e-01],
            [7.42888895e-04, -8.65475149e-03, -2.76357905e-01],
            [8.20831857e-04, -9.02312986e-03, -3.39952238e-01],
            [8.40464019e-04, -8.73730328e-03, -3.06670057e-01],
            [8.65933014e-04, -8.24166123e-03, -2.59439267e-01],
            [7.93615116e-04, -8.37012564e-03, -2.78239365e-01],
            [7.03674434e-04, -8.35493228e-03, -3.39852366e-01],
            [1.01758171e-03, -9.76429488e-03, -3.25293501e-01],
            [9.35461643e-04, -1.10727071e-02, -3.71318233e-01],
            [6.50356170e-04, -7.95097870e-03, -3.03604382e-01],
            [4.52718863e-04, -9.29366489e-03, -4.00429804e-01],
            [8.05631243e-04, -1.14197185e-02, -4.20384510e-01],
            [8.47515847e-04, -9.96860929e-03, -2.83518131e-01],
            [9.12570519e-04, -9.82150144e-03, -2.74205006e-01],
            [7.24486639e-04, -9.54421710e-03, -4.09721956e-01],
            [7.23932102e-04, -1.04094786e-02, -3.73786611e-01],
            [9.42679935e-04, -8.91587002e-03, -2.68828559e-01],
            [8.36852693e-04, -8.79226358e-03, -2.95625460e-01],
            [8.97587536e-04, -8.53956249e-03, -2.67201130e-01],
            [9.73117996e-04, -9.68782968e-03, -2.96134055e-01],
            [8.40719658e-04, -1.01328839e-02, -3.08517753e-01],
            [8.05147795e-04, -1.02102599e-02, -3.34566563e-01],
            [7.32762366e-04, -9.63864513e-03, -3.28912512e-01],
            [7.35557680e-04, -8.56144135e-03, -3.32824489e-01],
            [5.64761519e-04, -9.41350536e-03, -3.82889705e-01],
            [7.76138219e-04, -8.73339354e-03, -3.67425635e-01],
            [9.04001056e-04, -8.75587360e-03, -2.95097747e-01],
            [7.80944234e-04, -8.41312502e-03, -2.79195850e-01],
            [5.84829947e-04, -8.74106854e-03, -3.96927577e-01],
            [8.07971547e-04, -9.72232905e-03, -3.65218918e-01],
            [6.78723776e-04, -9.04532922e-03, -4.08114414e-01],
            [6.22965250e-04, -7.88578923e-03, -2.90448743e-01],
            [7.32733656e-04, -1.07535959e-02, -4.32718154e-01],
            [6.58381190e-04, -8.92271654e-03, -4.33382136e-01],
            [1.00531649e-03, -1.07509792e-02, -3.62623849e-01],
            [8.46047026e-04, -1.01839205e-02, -3.76365227e-01],
            [6.34992512e-04, -8.29172112e-03, -3.25484728e-01],
            [6.54061167e-04, -9.22206494e-03, -3.78913040e-01],
            [5.87803756e-04, -1.07805999e-02, -4.07335809e-01],
            [9.60152672e-04, -1.05862197e-02, -3.86568926e-01],
            [6.64066442e-04, -8.43692458e-03, -2.98853767e-01],
            [7.68113218e-04, -8.38922581e-03, -2.77459781e-01],
            [6.58937372e-04, -8.64966745e-03, -3.63956511e-01],
            [6.32340903e-04, -8.96211469e-03, -4.03768612e-01],
            [8.72236883e-04, -1.04965281e-02, -4.06766513e-01],
            [7.54816369e-04, -1.03919294e-02, -4.31409504e-01],
            [8.06423935e-04, -1.01636581e-02, -3.88547102e-01],
            [7.08226377e-04, -1.09369096e-02, -3.66961189e-01],
            [7.23926909e-04, -8.27437331e-03, -2.95350247e-01],
            [9.30520307e-04, -9.84282871e-03, -2.94285838e-01],
            [5.76072684e-04, -1.01960346e-02, -4.04020147e-01],
            [7.75376154e-04, -8.88283504e-03, -3.13071129e-01],
            [7.55250526e-04, -9.05332451e-03, -3.17749478e-01],
            [5.38246196e-04, -1.01366458e-02, -4.41523867e-01],
        ]
    )
    return fitnesses


def test_dominate():
    """Check ``dominate`` against hand-picked pairs of integer vectors.

    Each case is ``(first, second, expected)``, where ``expected`` is the
    truthiness required from ``dominate(first, second)``. Cases are checked
    in order and the test stops at the first mismatch.
    """
    cases = (
        # Identical vectors.
        ([1, 2, 3], [1, 2, 3], False),
        # First is smaller in every component.
        ([1, 2, 3], [2, 3, 4], False),
        # First is larger in every component.
        ([2, 3, 4], [1, 2, 3], True),
        # First is larger in two components and equal in the last.
        ([2, 3, 4], [1, 2, 4], True),
        # First is larger in two components but smaller in the last.
        ([2, 3, 4], [1, 2, 5], False),
    )
    for first, second, expected in cases:
        # Keep the call inside the assert so pytest reports the inputs on failure.
        assert bool(dominate(np.array(first), np.array(second))) is expected


def test_non_dominated_sort(fitnesses):
    """Check the fronts ``non_denominated_sort`` returns for the fixture data.

    With ``first_front_only=False`` the result must equal the eleven fronts
    listed below (lists of integer indices, in this exact order). With
    ``first_front_only=True`` the result must be a one-element list holding
    only the first of those fronts.
    """
    expected_fronts = [
        [19, 20, 24, 36, 37, 50, 53, 55, 62, 64, 65, 72, 77],
        [51, 87, 94, 80, 14, 59, 38, 47, 49, 85, 11, 63, 9],
        [0, 73, 52, 82, 54, 41, 95, 58, 28, 13, 48],
        [6, 18, 27, 45, 69, 71, 86, 97, 2, 90, 66, 81, 75],
        [26, 39, 88, 5, 98, 67, 92],
        [21, 35, 74, 79, 89, 76, 3, 23, 25, 31, 42, 43, 68, 57, 91],
        [12, 16, 34, 60, 7, 17, 83, 10, 40, 44, 78],
        [1, 8, 15, 56, 70, 30, 61, 93, 96],
        [33, 22, 32, 84],
        [46, 99],
        [4, 29],
    ]

    all_fronts = non_denominated_sort(fitnesses=fitnesses, first_front_only=False)
    assert all_fronts == expected_fronts

    # The first-front-only result is still a list of fronts, hence the slice.
    leading_front = non_denominated_sort(fitnesses=fitnesses, first_front_only=True)
    assert leading_front == expected_fronts[:1]
